import os
import pandas as pd
import numpy as np
import logging
LOGGER = logging.getLogger(__name__)
from init import PATHS
from C_PatentVariables import conventions_names_colors
# Lower-cased sub-sector name -> short name (the short names are used as
# column suffixes further down in this module).
dict_subsector_shortnames = {
    subsector_name.lower(): shortname
    for subsector_name, shortname in conventions_names_colors.get_dict_subsector_shortnames().items()
}



def calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, mask, name, var):
    """Add a (discounted) count and its citation-weighted variant to ``df_firmyear``.

    Two columns are written into ``df_firmyear`` (modified in place and returned):

    * ``'{var}{name}'``: sum of ``weight_discounted`` over the rows of
      ``df_firm_fam_info`` selected by ``mask``, per (firm_agg_id, year_agg).
    * ``'{var}Cit3y{name}'``: the same sum where every row is additionally
      multiplied by ``nb_citing_docdbfam_w3y`` divided by the mean of that
      variable over the selected families sharing its ``earliest_filing_year``.
      The weights therefore redistribute mass across families within a filing
      year: families cited less than the yearly average are weighted down.

    Args:
        df_firmyear: frame whose index is (firm_agg_id, year_agg); the new
            columns are aligned on that index, so firm-years without any
            selected row end up as NaN.
        df_firm_fam_info: one row per firm / year / family; needs the columns
            firm_agg_id, year_agg, docdb_family_id, earliest_filing_year,
            nb_citing_docdbfam_w3y and weight_discounted. It is not modified.
        mask: boolean Series aligned with ``df_firm_fam_info``.
        name: suffix of the new columns (e.g. ``''`` or ``'_Triad'``).
        var: prefix of the new columns (e.g. ``'Count'`` or ``'Stock'``).

    Returns:
        ``df_firmyear`` with the two columns added.
    """
    group_keys = ['firm_agg_id', 'year_agg']
    # Apply the mask exactly once, while it is still aligned with the frame it
    # was built from. (Merging first and masking afterwards would apply the
    # mask to a frame whose index has been reset.)
    selected = df_firm_fam_info[mask]
    df_firmyear['{}{}'.format(var, name)] = selected.groupby(group_keys)['weight_discounted'].sum()

    # Average number of citations per filing year, counting each family once.
    families = selected[['earliest_filing_year', 'docdb_family_id', 'nb_citing_docdbfam_w3y']].drop_duplicates()
    cit_yearly_mean = families.groupby('earliest_filing_year')['nb_citing_docdbfam_w3y'].mean()

    # Citations of each family relative to the average of its filing year.
    relative_citations = selected['nb_citing_docdbfam_w3y'] / selected['earliest_filing_year'].map(cit_yearly_mean)
    # For knowledge stocks weight_discounted also carries the age discount.
    selected = selected.assign(weighted_count_discounted=selected['weight_discounted'] * relative_citations)
    df_firmyear['{}Cit3y{}'.format(var, name)] = selected.groupby(group_keys)['weighted_count_discounted'].sum()
    return df_firmyear


def baseline_counts(df_firm_fam_info, var):
    """Build the firm-year frame with mean family size and the baseline counts.

    Side effect: ``nb_citing_docdbfam_w3y`` of ``df_firm_fam_info`` is
    overwritten in place (missing values set to 0, then 1 added to every row).

    Args:
        df_firm_fam_info: one row per firm / year / family.
        var: prefix of the count columns (e.g. 'Count' or 'Stock').

    Returns:
        Frame indexed by (firm_agg_id, year_agg) with ``MeanFamSize`` plus the
        plain and citation-weighted counts for: all families, families in at
        least two countries, families in at least one OECD country, and
        triadic families. Missing values are set to 0 and all columns are cast
        to int.
    """
    LOGGER.info('               start baseline_counts')
    # Citation weighting: shift every family by +1 citation, otherwise a
    # family without citations would weigh zero and be indistinguishable from
    # a firm-year without any patent.
    citations = df_firm_fam_info['nb_citing_docdbfam_w3y'].fillna(0)
    df_firm_fam_info['nb_citing_docdbfam_w3y'] = citations + 1

    # The firm-year structure starts from the average family size.
    mean_family_size = df_firm_fam_info.groupby(['firm_agg_id', 'year_agg'])['docdb_family_size'].mean()
    df_firmyear = pd.DataFrame(mean_family_size.rename('MeanFamSize'))

    # (column suffix, row selection) for each baseline variant, in output order.
    selections = (
        ('', df_firm_fam_info['docdb_family_id'].notnull()),   # basically all rows
        ('_2Ctry', df_firm_fam_info['NbrCountries'] > 1),      # filed in at least 2 countries
        ('_1OECD', df_firm_fam_info['NbrOECDcountries'] > 0),  # filed in at least 1 OECD country
        ('_Triad', df_firm_fam_info['ifTriadic'] == True),     # triadic families
    )
    for suffix, selection in selections:
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, selection, suffix, var)

    # NOTE(review): the int cast truncates fractional values (mean family
    # size, citation-weighted sums, discounted weights) - confirm intended.
    return df_firmyear.fillna(0).astype(int)


def counting_families(df_firm_year_docdbid, namefile, var):
    """Aggregate family-level information to firm-year counts.

    Reads the family information file(s) selected by ``namefile``, attaches
    them to the firm / year / family correspondence and runs the baseline
    counts, the counts by technology type and the final column clean-up.

    Side effect: when ``df_firm_year_docdbid`` has no ``weight_discounted``
    column, one filled with 1 is added to the caller's frame.

    Args:
        df_firm_year_docdbid: correspondence with columns year_agg,
            firm_agg_id and docdb_family_id (optionally weight_discounted).
        namefile: name of the sample; selects the input files.
        var: prefix of the count columns ('Count' or 'Stock').

    Returns:
        Firm-year frame produced by the counting steps.
    """
    LOGGER.info('           start counting_families')
    useful_columns = ['docdb_family_id', 'earliest_filing_year', 'ifTriadic', 'NbrCountries', 'NbrOECDcountries', 'Granted', 'nb_citing_docdbfam_w3y', 'docdb_family_size']

    # Pick the family-level file(s) for this sample.
    if 'suppliers' in namefile:
        family_files = ['Data_outputted/C_PatentVariables/Families_of_suppliers.csv']
    elif namefile in ('ipc_cpc_transpo', 'OtherFirms_in_ipc_cpc_transpo'):
        family_files = ['Data_outputted/C_PatentVariables/Families_of_ipc_cpc_transpo.csv']
    elif namefile == 'nace_motor':
        family_files = ['Data_outputted/C_PatentVariables/Families_of_nace_motor.csv']
    else:
        family_files = [
            'Data_outputted/C_PatentVariables/Families_of_OEMs.csv',
            'Data_outputted/C_PatentVariables/Families_of_subsidiaries.csv',
        ]
    family_frames = [pd.read_csv(PATHS.dropbox / relative_path, usecols=useful_columns) for relative_path in family_files]
    if len(family_frames) == 1:
        df_fam_info = family_frames[0]
    else:
        # Only the combined OEM + subsidiaries sample is de-duplicated.
        df_fam_info = pd.concat(family_frames).drop_duplicates()

    # A unit weight lets the same code path produce plain counts; knowledge
    # stocks arrive with their own (discounted) weights.
    if 'weight_discounted' not in df_firm_year_docdbid.columns:
        df_firm_year_docdbid['weight_discounted'] = 1

    # One row per firm-year-family, otherwise families would be double-counted.
    unique_keys = ['year_agg', 'firm_agg_id', 'docdb_family_id']
    firm_year_family = df_firm_year_docdbid.groupby(unique_keys).max().reset_index()

    # Attach the family-level information to the firm identifiers.
    df_firm_fam_info = firm_year_family.merge(df_fam_info, on='docdb_family_id', how='left')

    df_firmyear = baseline_counts(df_firm_fam_info, var)
    df_firmyear = add_counts_by_type(df_firmyear, df_firm_fam_info, namefile, var)
    return drop_uninformative_columns_for_ipc_cpc_transpo(df_firmyear, namefile)


def drop_uninformative_columns_for_ipc_cpc_transpo(df_firmyear, namefile):
    """Drop columns that carry no information for the 'ipc_cpc_transpo' sample.

    For any other ``namefile`` the frame is returned untouched. For
    'ipc_cpc_transpo' every column whose name contains '3y', 'NotTransport',
    'noCPCIPC' or 'otherCPCIPC' is removed. Rationale, as noted by the author:

    * the 3-year citation weights are built within years and this sample is
      aggregated at year level only, so the weighted counts equal the plain
      counts;
    * the sample contains transport families by construction, so the
      no-code / other-code / not-transport counts are all zero.
    """
    if namefile != 'ipc_cpc_transpo':
        return df_firmyear
    uninformative_markers = ('3y', 'NotTransport', 'noCPCIPC', 'otherCPCIPC')
    cols_to_drop = [
        column for column in df_firmyear.columns
        if any(marker in column for marker in uninformative_markers)
    ]
    return df_firmyear.drop(columns=cols_to_drop)



def add_counts_by_type(df_firmyear, df_firm_fam_info, namefile, var):
    """Add counts by technology type, sector and sub-sector to ``df_firmyear``.

    Reads the CPC/IPC classification file(s) selected by ``namefile``, merges
    them onto ``df_firm_fam_info`` and delegates to the per-category helpers.

    Args:
        df_firmyear: firm-year frame indexed by (firm_agg_id, year_agg).
        df_firm_fam_info: one row per firm / year / family.
        namefile: name of the sample; selects the input files.
        var: prefix of the count columns ('Count' or 'Stock').

    Returns:
        ``df_firmyear`` with the additional columns, missing values set to 0
        and all columns cast to int.
    """
    LOGGER.info('               start add_counts_by_type')
    # Pick the family-level classification file(s) for this sample.
    if 'suppliers' in namefile:
        code_files = ['Data_outputted/C_PatentVariables/Families_of_suppliers_cpc_ipc.csv']
    elif namefile in ('ipc_cpc_transpo', 'OtherFirms_in_ipc_cpc_transpo'):
        code_files = ['Data_outputted/C_PatentVariables/Families_of_ipc_cpc_transpo_cpc_ipc.csv']
    elif namefile == 'nace_motor':
        code_files = ['Data_outputted/C_PatentVariables/Families_of_nace_motor_cpc_ipc.csv']
    else:
        code_files = [
            'Data_outputted/C_PatentVariables/Families_of_OEMs_cpc_ipc.csv',
            'Data_outputted/C_PatentVariables/Families_of_subsidiaries_cpc_ipc.csv',
        ]
    code_frames = [pd.read_csv(PATHS.dropbox / relative_path) for relative_path in code_files]
    if len(code_frames) == 1:
        df_fam_codes = code_frames[0]
    else:
        df_fam_codes = pd.concat(code_frames).drop_duplicates()
    df_firm_fam_info = df_firm_fam_info.merge(df_fam_codes, on='docdb_family_id', how='left')

    group_keys = ['firm_agg_id', 'year_agg']
    # Families without any IPC or CPC code information.
    without_codes = df_firm_fam_info['CPC'].isnull() & df_firm_fam_info['IPC'].isnull()
    df_firmyear['{}_noCPCIPC'.format(var)] = df_firm_fam_info[without_codes].groupby(group_keys)['docdb_family_id'].nunique()
    # Families with some IPC/CPC code but no clean/dirty type.
    codes_but_no_type = ~without_codes & df_firm_fam_info['Type'].isnull()
    df_firmyear['{}_otherCPCIPC'.format(var)] = df_firm_fam_info[codes_but_no_type].groupby(group_keys)['docdb_family_id'].nunique()

    # The helpers below use string operations, so replace missing labels by ''.
    for label_column in ('Type', 'Sector', 'Sub-sector', 'BatTrans'):
        df_firm_fam_info[label_column] = df_firm_fam_info[label_column].fillna('')

    # Order matters: `subsectors` adds the 'Sub-sector_exclusive' columns to
    # df_firm_fam_info in place and the two steps after it read them.
    counting_steps = (
        clean_grey_dirty,
        nottransport,
        subsectors,
        battery_transport_vs_nontransport_applications,
        cleancars,
    )
    for counting_step in counting_steps:
        df_firmyear = counting_step(df_firmyear, df_firm_fam_info, var)

    # NOTE(review): the int cast truncates fractional (weighted / discounted)
    # values - confirm intended.
    return df_firmyear.fillna(0).astype(int)


def clean_grey_dirty(df_firmyear, df_firm_fam_info, var):
    """Add counts of families by technology type (Clean, Grey, Dirty).

    For each type two variants are added: a non-exclusive one (the type
    appears among the family's types) and an exclusive one (suffix '_excl'),
    where any family combining several types is classified as 'Grey'.

    Side effect: ``df_firm_fam_info['Type']`` is normalised in place and a
    ``Type_exclusive`` column is added.
    """
    LOGGER.info('                   start clean_dirty')
    # 'Clean/Grey' labels (they can come from Y02T70, mitigation air/maritime)
    # are really grey, so relabel them before splitting.
    relabelled = df_firm_fam_info['Type'].str.replace('Clean/Grey', 'Grey')
    # Canonical form: unique, stripped, alphabetically sorted, comma-joined.
    df_firm_fam_info['Type'] = relabelled.apply(
        lambda types: ','.join(sorted({part.strip() for part in types.split(',')}))
    )

    # Exclusive types: every mixture collapses to 'Grey'.
    mixed_to_exclusive = {
        'Dirty,Grey': 'Grey',
        'Clean,Grey': 'Grey',
        'Clean,Dirty': 'Grey',
        'Clean,Dirty,Grey': 'Grey',
    }
    df_firm_fam_info['Type_exclusive'] = df_firm_fam_info['Type'].apply(
        lambda types: mixed_to_exclusive.get(types, types)
    )

    for techtype in ('Clean', 'Grey', 'Dirty'):
        mentions_type = df_firm_fam_info['Type'].str.contains(techtype)
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, mentions_type, techtype, var)
        only_type = df_firm_fam_info['Type_exclusive'] == techtype
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, only_type, techtype + '_excl', var)
    return df_firmyear


def nottransport(df_firmyear, df_firm_fam_info, var):
    """Add counts of families whose sectors do not include transport.

    A family is selected when its ``Sector`` label is non-empty and does not
    contain 'Transport' (i.e. codes only in sectors such as electricity or
    industry). The columns use the suffix '_NotTransport'.
    """
    LOGGER.info('                   start nottransport')
    sector = df_firm_fam_info['Sector']
    has_sector = sector != ''
    outside_transport = ~sector.str.contains('Transport')
    return calculate_count_and_weigthedcount(
        df_firmyear, df_firm_fam_info, outside_transport & has_sector, '_NotTransport', var
    )


def subsectors(df_firmyear, df_firm_fam_info, var):
    """Add counts of families per sub-sector.

    For a fixed list of sub-sectors a non-exclusive and an exclusive ('_excl')
    count are added; batteries, fuel cells, electric and hybrid vehicles also
    get an '_excl_wH2' variant. Two further counts cover families classified
    as both batteries and fuel cells. Finally, any other sub-sector accounting
    for more than 0.1% of all sub-sector mentions gets a non-exclusive count.

    Side effect: the columns 'Sub-sector_exclusive' and
    'Sub-sector_exclusive_wH2' are added to ``df_firm_fam_info`` in place.

    Raises:
        KeyError: if a sub-sector has no entry in the short-name dictionary.
    """
    LOGGER.info('                   start subsectors')
    df_firm_fam_info = cleaning_subsectors(df_firm_fam_info)
    df_firm_fam_info = cleaning_subsectors_wH2(df_firm_fam_info)
    shortname_of = {key.lower(): shortname for key, shortname in dict_subsector_shortnames.items()}
    list_subsectors_exclusive = ['batteries', 'biofuels', 'car efficiency', 'electric vehicles', 'enabling technologies', 'energy storage', 'fuel cells', 'hybrid vehicles', 'hydrogen', 'ice efficiency', 'internal combustion engine', 'mitigation air', 'mitigation maritime', 'mitigation rail']
    # Only these sub-sectors are affected by the with-hydrogen definition.
    with_h2_variant = ('batteries', 'fuel cells', 'electric vehicles', 'hybrid vehicles')

    for techsubsector in list_subsectors_exclusive:
        shortname = shortname_of[techsubsector]
        column_and_suffix = [
            ('Sub-sector', '_{}'.format(shortname)),
            ('Sub-sector_exclusive', '_{}_excl'.format(shortname)),
        ]
        if techsubsector in with_h2_variant:
            column_and_suffix.append(('Sub-sector_exclusive_wH2', '_{}_excl_wH2'.format(shortname)))
        for column, suffix in column_and_suffix:
            selection = df_firm_fam_info[column].str.contains(techsubsector)
            df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, selection, suffix, var)

    # Families classified as both batteries and fuel cells.
    for column, suffix in (('Sub-sector_exclusive', '_bothbatfc'), ('Sub-sector_exclusive_wH2', '_bothbatfc_wH2')):
        selection = df_firm_fam_info[column] == 'batteries,fuel cells'
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, selection, suffix, var)

    # Other sub-sectors worth capturing: share (in %) of every sub-sector
    # among all mentions, keeping those not handled above and above 0.1%.
    mentions = [
        label
        for labels in df_firm_fam_info['Sub-sector'].tolist()
        for label in labels.split(',')
        if label != ''
    ]
    shares = 100 * pd.Series(mentions).value_counts() / len(mentions)
    shares = shares[~shares.index.isin(list_subsectors_exclusive)]
    shares = shares[shares > 0.1]
    for techsubsector in list(shares.index):
        shortname = shortname_of[techsubsector]
        selection = df_firm_fam_info['Sub-sector'].str.contains(techsubsector)
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, selection, '_{}'.format(shortname), var)
    return df_firmyear


def battery_transport_vs_nontransport_applications(df_firmyear, df_firm_fam_info, var):
    """Add counts of battery families split by transport / non-transport use.

    A battery family counts as transport ('_BatTrans') when ``BatTrans`` is
    'Trans' or 'NonTrans,Trans', and as non-transport ('_BatNotTr') when it is
    'NonTrans'. Both splits are computed on the non-exclusive 'Sub-sector'
    column first and then on 'Sub-sector_exclusive' (suffix '_excl'); the
    latter column must already exist on ``df_firm_fam_info``.
    """
    LOGGER.info('                   start battery_transport_vs_nontransport_applications')
    transport_use = df_firm_fam_info['BatTrans'].isin(['Trans', 'NonTrans,Trans'])
    non_transport_use = df_firm_fam_info['BatTrans'] == 'NonTrans'
    # (sub-sector column, suffix appended to the column names)
    variants = (('Sub-sector', ''), ('Sub-sector_exclusive', '_excl'))
    for subsector_column, excl_suffix in variants:
        is_battery = df_firm_fam_info[subsector_column].str.contains('batteries')
        df_firmyear = calculate_count_and_weigthedcount(
            df_firmyear, df_firm_fam_info, is_battery & transport_use, '_BatTrans' + excl_suffix, var
        )
        df_firmyear = calculate_count_and_weigthedcount(
            df_firmyear, df_firm_fam_info, is_battery & non_transport_use, '_BatNotTr' + excl_suffix, var
        )
    return df_firmyear


def cleaning_subsectors(df_firm_fam_info):
    """Add the column 'Sub-sector_exclusive' to ``df_firm_fam_info`` (in place).

    Each 'Sub-sector' label is collapsed to a single category so that a family
    is not counted in several sub-sectors (e.g. a family that is both EV and
    hybrid is counted under hybrid). Batteries and fuel cells are the one
    exception: a family mentioning both keeps the combined label
    'batteries,fuel cells'. Labels matching no category are left unchanged.

    Returns:
        The same frame, with the new column.
    """
    # Checked in this order once batteries / fuel cells have been ruled out;
    # the first category found in the label wins.
    priority = (
        'hybrid vehicles', 'electric vehicles', 'hydrogen', 'energy storage',
        'biofuels', 'ice efficiency', 'car efficiency', 'enabling technologies',
        'internal combustion engine', 'mitigation air', 'mitigation maritime',
        'mitigation rail',
    )

    def to_exclusive(label):
        has_batteries = 'batteries' in label
        has_fuel_cells = 'fuel cells' in label
        if has_batteries and has_fuel_cells:
            return 'batteries,fuel cells'
        if has_batteries:
            return 'batteries'
        if has_fuel_cells:
            return 'fuel cells'
        for category in priority:
            if category in label:
                return category
        return label

    df_firm_fam_info['Sub-sector_exclusive'] = df_firm_fam_info['Sub-sector'].apply(to_exclusive)
    return df_firm_fam_info



def cleaning_subsectors_wH2(df_firm_fam_info):
    """Add the column 'Sub-sector_exclusive_wH2' to ``df_firm_fam_info`` (in place).

    Same idea as the exclusive sub-sector classification, except that
    hydrogen is treated as fuel cells: a label mentioning batteries together
    with fuel cells or hydrogen becomes 'batteries,fuel cells', and a label
    mentioning fuel cells or hydrogen without batteries becomes 'fuel cells'.
    There is consequently no separate 'hydrogen' category here. Labels
    matching no category are left unchanged.

    Returns:
        The same frame, with the new column.
    """
    # Checked in this order once batteries / fuel cells / hydrogen have been
    # ruled out; the first category found in the label wins.
    priority = (
        'hybrid vehicles', 'electric vehicles', 'energy storage', 'biofuels',
        'ice efficiency', 'car efficiency', 'enabling technologies',
        'internal combustion engine', 'mitigation air', 'mitigation maritime',
        'mitigation rail',
    )

    def to_exclusive(label):
        has_batteries = 'batteries' in label
        has_fuel_cells_or_h2 = ('fuel cells' in label) or ('hydrogen' in label)
        if has_batteries and has_fuel_cells_or_h2:
            return 'batteries,fuel cells'
        if has_batteries:
            return 'batteries'
        if has_fuel_cells_or_h2:
            return 'fuel cells'
        for category in priority:
            if category in label:
                return category
        return label

    df_firm_fam_info['Sub-sector_exclusive_wH2'] = df_firm_fam_info['Sub-sector'].apply(to_exclusive)
    return df_firm_fam_info


def cleancars(df_firmyear, df_firm_fam_info, var):
    """Add counts of families related to clean cars.

    A family is selected when its sub-sector label matches any of the
    sub-sectors returned by
    ``conventions_names_colors.get_list_of_subsectors_in_cleancars()``.
    Counts are added for all such families, for the triadic ones ('_Triad')
    and for those filed in at least two countries ('_2Ctry'), first on the
    non-exclusive 'Sub-sector' column ('_CleanCar') and then on
    'Sub-sector_exclusive' ('_CleanCar_excl'), which must already exist.
    """
    LOGGER.info('                   start cleancars')
    # Regex alternation of the clean-car sub-sectors, e.g.
    # 'electric vehicles|fuel cells|energy storage|hybrid vehicles|hydrogen|batteries'
    cleancar_pattern = '|'.join(conventions_names_colors.get_list_of_subsectors_in_cleancars())
    is_triadic = df_firm_fam_info['ifTriadic'] == True
    in_two_countries = df_firm_fam_info['NbrCountries'] > 1

    for subsector_column, suffix in (('Sub-sector', '_CleanCar'), ('Sub-sector_exclusive', '_CleanCar_excl')):
        is_cleancar = df_firm_fam_info[subsector_column].str.contains(cleancar_pattern, regex=True)
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, is_cleancar, suffix, var)
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, is_cleancar & is_triadic, suffix + '_Triad', var)
        df_firmyear = calculate_count_and_weigthedcount(df_firmyear, df_firm_fam_info, is_cleancar & in_two_countries, suffix + '_2Ctry', var)
    return df_firmyear


def calculate_weights(df_agg_corres, df_firmyear, df_bvdid_docdbid, aggregation_level, namefile):
    """Compute how much of a bvdid's families the aggregate firm has absorbed.

    Background: each firm_agg_id (e.g. an OEM at level 1) regroups several
    bvdids in a given year, and firms/subsidiaries are acquired and sold off
    over time (e.g. GM acquires firm A in 2008 and sells it in 2012). While
    two firms are joint they assimilate part of each other's knowledge stock,
    so it would be too extreme to assume that GM is left with none of A's
    stock after the separation (and vice versa). After separation the stock
    of a firm is therefore its own stock plus a fraction w of the other
    firm's stock, with w depending on the relative size of the portfolios:

        Stock of GM in 2012 = Stock of GM + w * Stock of A
        w = size_GM / (size_GM + size_A)

    E.g. with 9 patents for GM and 1 for A, GM absorbs w = 90% of A's stock
    and A absorbs 10% of GM's; two firms of equal size absorb 50% of each
    other.

    The weight is computed as ``1 - Count_bvdid / Count_OEM`` where
    ``Count_OEM`` is the 'Count' of the aggregate firm in ``df_firmyear`` for
    that year and ``Count_bvdid`` the 'Count' of the bvdid alone, obtained by
    re-running ``counting_families`` at bvdid / filing-year level. Rows
    without a matching count get a missing weight (both merges are left
    joins).

    Args:
        df_agg_corres: correspondence with columns 'Year', ``aggregation_level``
            and 'bvdid'.
        df_firmyear: firm-year counts indexed by (firm_agg_id, year_agg),
            containing a 'Count' column.
        df_bvdid_docdbid: correspondence between 'bvdid' and families, with
            an 'earliest_filing_year' column.
        aggregation_level: name of the column identifying the aggregate firm
            (may be 'bvdid' itself).
        namefile: name of the sample, forwarded to ``counting_families``.

    Returns:
        Frame with columns 'earliest_filing_year' (the year of the
        correspondence, renamed), ``aggregation_level``, 'bvdid' and 'weight'.
    """
    LOGGER.info('       start calculate_weights')
    # De-duplicate the key columns (aggregation_level may equal 'bvdid')
    # while keeping a fixed order; a set literal would make the column order
    # depend on string hashing and hence vary from run to run.
    key_columns = list(dict.fromkeys(['Year', aggregation_level, 'bvdid']))
    weights = df_agg_corres[key_columns].drop_duplicates()

    # Size of the aggregate firm in each year.
    counts_agg = df_firmyear.reset_index()[['year_agg', 'firm_agg_id', 'Count']].drop_duplicates()
    weights = weights.merge(counts_agg, left_on=['Year', aggregation_level], right_on=['year_agg', 'firm_agg_id'], how='left')
    weights = weights.drop(columns=['year_agg', 'firm_agg_id']).rename(columns={'Count': 'Count_OEM'})

    # Size of each bvdid on its own: count its families by filing year.
    df_firm_fam_weigths = df_bvdid_docdbid.rename(columns={'bvdid': 'firm_agg_id', 'earliest_filing_year': 'year_agg'})
    df_firmyear_weigths = counting_families(df_firm_fam_weigths, namefile, 'Count')
    counts_bvdid = df_firmyear_weigths.reset_index()[['year_agg', 'firm_agg_id', 'Count']].drop_duplicates()
    weights = weights.merge(counts_bvdid, left_on=['Year', 'bvdid'], right_on=['year_agg', 'firm_agg_id'], how='left')
    weights = weights.drop(columns=['year_agg', 'firm_agg_id']).rename(columns={'Count': 'Count_bvdid'})

    # E.g. if Count_bvdid / Count_OEM = 0.01 (1%), the aggregate firm absorbs
    # 99% of the families of the bvdid.
    weights['weight'] = 1 - weights['Count_bvdid'] / weights['Count_OEM']
    weights = weights.drop(columns=['Count_OEM', 'Count_bvdid'])
    weights = weights.rename(columns={'Year': 'earliest_filing_year'})
    return weights


def add_cumulative_stocks(df_agg_corres, df_bvdid_docdbid, df_firmyear, aggregation_level, namefile):
    """Add discounted knowledge-stock columns ('Stock...') to ``df_firmyear``.

    Goal: build a correspondence telling which families to aggregate for a
    given (year_agg, firm_agg_id). Since we are interested in stocks, all
    families filed in the past are aggregated.

    Simple case - aggregation at bvdid level (e.g. suppliers, or the firms
    behind ipc_cpc_transpo): for each year t, list all families filed up to
    year t. This is ``currently_owned``; ``previously_owned`` is then empty,
    which does not break the computation.

    Case with subsidiaries changing over time: bvdids join and leave an OEM.
    When computing the stock in year t we also count, to some extent, the
    families of a bvdid that was owned by the OEM up to an earlier year. The
    extent is given by the weights from ``calculate_weights``.

    Finally, families are discounted the further in the past they were filed
    (perpetual inventory method, K_t = PatentCount_t + (1 - delta) * K_{t-1},
    with a depreciation rate delta of 20%).

    Args:
        df_agg_corres: correspondence with columns 'Year', ``aggregation_level``
            and 'bvdid'.
        df_bvdid_docdbid: correspondence with columns 'bvdid',
            'docdb_family_id' and 'earliest_filing_year'.
        df_firmyear: firm-year counts indexed by (firm_agg_id, year_agg).
        aggregation_level: name of the column identifying the aggregate firm
            (may be 'bvdid' itself).
        namefile: name of the sample, forwarded to the counting functions.

    Returns:
        ``df_firmyear`` outer-merged with the stock columns on
        (firm_agg_id, year_agg).
    """
    LOGGER.info('   start add_cumulative_stocks')

    def unique_columns(*names):
        # aggregation_level may equal 'bvdid'; drop the duplicate while
        # keeping a fixed column order (a set would not guarantee one).
        return list(dict.fromkeys(names))

    # First, the bvdids that belong to the aggregate firm in year t: collect
    # all their families.
    LOGGER.info('               start currently_owned')
    currently_owned = df_agg_corres[unique_columns('Year', aggregation_level, 'bvdid')].drop_duplicates()
    currently_owned = currently_owned.merge(df_bvdid_docdbid, on='bvdid', how='left')
    # Keep only families filed up to year t.
    currently_owned = currently_owned[currently_owned['earliest_filing_year'] <= currently_owned['Year']]
    # Weight 1: these families all belong to the firm to the same extent
    # (the discount by age is applied further down).
    currently_owned = currently_owned.assign(weight=1)
    currently_owned = currently_owned[['Year', aggregation_level, 'docdb_family_id', 'earliest_filing_year', 'weight']]

    # Second, the bvdids that used to be owned by the aggregate firm. Start
    # from the last year in which the firm owned each bvdid.
    LOGGER.info('               start previously_owned')
    last_owned = df_agg_corres.groupby(unique_columns(aggregation_level, 'bvdid'))['Year'].max().rename('year_last').reset_index()
    previously_owned = df_agg_corres[['Year', aggregation_level]].drop_duplicates().merge(last_owned, on=aggregation_level)
    # Keep only the years in which the bvdid no longer belongs to the firm.
    previously_owned = previously_owned[previously_owned['Year'] > previously_owned['year_last']]
    previously_owned = previously_owned.merge(df_bvdid_docdbid, on='bvdid', how='left')
    # Keep only families filed up to the last year of ownership.
    previously_owned = previously_owned[previously_owned['earliest_filing_year'] <= previously_owned['year_last']]
    # Weight indicating how much of the bvdid's families the firm absorbed.
    weights = calculate_weights(df_agg_corres, df_firmyear, df_bvdid_docdbid, aggregation_level, namefile)
    weight_keys = unique_columns(aggregation_level, 'bvdid', 'earliest_filing_year')
    previously_owned = previously_owned.merge(weights, on=weight_keys, how='left')
    previously_owned = previously_owned[['Year', aggregation_level, 'docdb_family_id', 'earliest_filing_year', 'weight']]

    LOGGER.info('               combine')
    corres_cum_stock = pd.concat([currently_owned, previously_owned])
    # Depreciate by 20% per year since filing. Vectorised: a row-wise apply
    # is slow and does not reliably return a Series on an empty frame.
    age_in_years = (corres_cum_stock['Year'] - corres_cum_stock['earliest_filing_year']).astype(float)
    corres_cum_stock['weight_discounted'] = corres_cum_stock['weight'].astype(float) * (0.8 ** age_in_years)
    corres_cum_stock = corres_cum_stock[['Year', aggregation_level, 'docdb_family_id', 'weight_discounted']]
    corres_cum_stock = corres_cum_stock.rename(columns={aggregation_level: 'firm_agg_id', 'Year': 'year_agg'})

    # Same counting function as for the plain counts; because the frame
    # already carries 'weight_discounted', the sums are discounted stocks.
    corres_cum_stock = counting_families(corres_cum_stock, namefile, 'Stock')
    corres_cum_stock = corres_cum_stock.drop(columns='MeanFamSize')

    # Combine the plain counts and the cumulative stocks.
    df_firmyear = df_firmyear.merge(corres_cum_stock, on=['firm_agg_id', 'year_agg'], how='outer')
    return df_firmyear


def add_names(df_firmyear, namefile):
    """Attach firm names to the firm-year panel.

    The name lookup table is picked from ``namefile``:

    * ``'subsidiaries'``: subsidiary ids and names from OEM_and_Subsidiaries.csv
    * any value containing ``'suppliers'``: ids and names from suppliers_ids.csv
    * ``'OtherFirms_in_ipc_cpc_transpo'``: the firm-name file outer-merged with
      the firm-NAICS file on ``bvdid``
    * anything else: the Level1 and Level2 OEM ids/names of
      OEM_and_Subsidiaries.csv, stacked into a single (firm_agg_id, name) table

    Parameters
    ----------
    df_firmyear : pandas.DataFrame
        Firm-year panel; must expose ``firm_agg_id`` as a column once its
        index has been reset.
    namefile : str
        Label of the firm set, used to choose the lookup table.

    Returns
    -------
    pandas.DataFrame
        The lookup table right-merged onto ``df_firmyear`` via ``firm_agg_id``,
        so every panel row is kept whether or not a name was found.
    """
    LOGGER.info('   start add_names')
    df_firmyear = df_firmyear.reset_index()
    if namefile == 'subsidiaries':
        data_names = pd.read_csv(PATHS.dropbox / 'Data_outputted/A_AutoIndustry/OEM_and_Subsidiaries.csv', usecols=['Sub_BvDID', 'Sub_Name']).drop_duplicates()
        data_names = data_names.rename(columns={'Sub_BvDID': 'firm_agg_id', 'Sub_Name': 'name'})
        data_names = data_names[data_names['firm_agg_id'].notnull()]
    elif 'suppliers' in namefile:
        data_names = pd.read_csv(PATHS.dropbox / 'Data_outputted/B_FactsetVariables/suppliers_ids.csv', usecols=['bvdid', 'bvdNAMEmatch']).drop_duplicates()
        data_names = data_names.rename(columns={'bvdid': 'firm_agg_id', 'bvdNAMEmatch': 'name'})
        data_names = data_names[data_names['firm_agg_id'].notnull()]
    elif namefile == 'OtherFirms_in_ipc_cpc_transpo':
        data_naics = pd.read_csv(PATHS.dropbox / 'Data_outputted/temp/FirmsNaics_of_Families_of_ipc_cpc_transpo.csv')
        data_names = pd.read_csv(PATHS.dropbox / 'Data_outputted/temp/FirmsNames_of_Families_of_ipc_cpc_transpo.csv')
        data_names = data_names.merge(data_naics, on='bvdid', how='outer')
        data_names = data_names.rename(columns={'bvdid': 'firm_agg_id', 'NAME': 'name'})
    else:
        oem_table = pd.read_csv(PATHS.dropbox / 'Data_outputted/A_AutoIndustry/OEM_and_Subsidiaries.csv', usecols=['OEM_Level1_ID', 'Level1_OrbisName', 'OEM_Level2_ID', 'Level2_OrbisName']).drop_duplicates()
        # read_csv returns the columns in file order (the order of `usecols`
        # is ignored), so each (id, name) pair is selected by name rather than
        # by position before the two levels are stacked.
        level_columns = (
            ('OEM_Level1_ID', 'Level1_OrbisName'),
            ('OEM_Level2_ID', 'Level2_OrbisName'),
        )
        data_names = pd.concat([
            oem_table[[id_col, name_col]].rename(columns={id_col: 'firm_agg_id', name_col: 'name'})
            for id_col, name_col in level_columns
        ])
    # NOTE: keep a single row per firm_agg_id (the first one encountered), so
    # the merge below cannot duplicate panel rows; for the NAICS-augmented
    # table this means only the first NAICS code of a firm is kept.
    data_names = data_names.drop_duplicates('firm_agg_id')
    df_firmyear = data_names.merge(df_firmyear, on='firm_agg_id', how='right')
    return df_firmyear


def get_docdbids_of_bvdids(namefile):
    """Build the bvdid / docdb family correspondence for one set of firms.

    The PATSTAT-Orbis link file is matched against the patent family file(s)
    of ``namefile`` twice, once through ``docdb_family_id`` and once through
    ``appln_id``; the union of both matches is returned.

    namefile = 'suppliers' / 'ipc_cpc_transpo' / 'nace_motor' / 'OEMs'
        'OEMs' stacks Patents_of_OEMs.csv and Patents_of_subsidiaries.csv;
        any other value reads Patents_of_{namefile}.csv.

    Returns a df with the columns bvdid, docdb_family_id and
    earliest_filing_year, without rows lacking a bvdid. For 'suppliers' and
    'OEMs' only the bvdids listed in the corresponding id file are kept.
    """
    # IMPORT
    patstat_links = pd.read_csv(PATHS.privatedata / 'Patstat_orbis_correspondence/Orbis_PATSTAT_updatePEM_anonymized.csv')
    family_cols = ['docdb_family_id', 'appln_id', 'earliest_filing_year']
    if namefile == 'OEMs':
        oem_families = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Patents_of_OEMs.csv', usecols=family_cols)
        subsidiary_families = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Patents_of_subsidiaries.csv', usecols=family_cols)
        families = pd.concat([oem_families, subsidiary_families]).drop_duplicates()
    else:
        families = pd.read_csv(PATHS.dropbox / f'Data_outputted/C_PatentVariables/Patents_of_{namefile}.csv', usecols=family_cols)

    # Link families to bvdids: first via docdb ids, then via application ids
    matches = []
    for link_key in ('docdb_family_id', 'appln_id'):
        links = patstat_links[[link_key, 'bvdid']].drop_duplicates()
        linked = families.merge(links, on=link_key, how='left')
        linked = linked[['bvdid', 'docdb_family_id', 'earliest_filing_year']].drop_duplicates()
        matches.append(linked[linked['bvdid'].notnull()])
    correspondence = pd.concat(matches).drop_duplicates()

    # Keep only bvdids that correspond to OEMs/suppliers...
    if namefile == 'suppliers':
        supplier_ids = pd.read_csv(PATHS.dropbox / 'Data_outputted/B_FactsetVariables/suppliers_ids.csv')
        return correspondence[correspondence['bvdid'].isin(supplier_ids['bvdid'].unique())]
    if namefile == 'OEMs':
        oem_table = pd.read_csv(PATHS.dropbox / 'Data_outputted/A_AutoIndustry/OEM_and_Subsidiaries.csv').drop_duplicates()
        id_columns = ['Level1_bvdid', 'Level2_bvdid', 'Sub_BvDID']
        oem_bvdids = pd.concat([oem_table[col].dropna() for col in id_columns]).unique().tolist()
        return correspondence[correspondence['bvdid'].isin(oem_bvdids)]
    return correspondence


def main(df_agg_corres, df_bvdid_docdbid, aggregation_level='OEM_Level1_ID', namefile=''):
    """
    Build the firm-year patent panel for ``namefile`` and save it as
    Panel_{namefile}_FamInfo.csv. Nothing is returned.

    df_bvdid_docdbid: df with 3 columns: bvdids, docdbids and earliest filing year.
        It should hold all the docdbids of the families connected to the
        bvdids of interest.
    df_agg_corres: the structure describing how docdbids are aggregated over
        the years. The columns read here are ``aggregation_level``,
        docdb_family_id and earliest_filing_year.
        In the simplest case the aggregation is done at the bvdid level, and
        df_agg_corres carries the same information as df_bvdid_docdbid.
        For OEMs several bvdids are regrouped into OEM_Level1_ID or
        OEM_Level2_ID, so the two inputs differ.
    aggregation_level: name of the firm identifier column to aggregate on.
    namefile: label of the firm set; used in the log lines and the file name.
    """
    LOGGER.info('   firm_year_aggregation.py for: {}'.format(namefile))

    # Yearly family counts, computed on generic firm/year column names
    LOGGER.info('   start counting_families')
    to_generic = {aggregation_level: 'firm_agg_id', 'earliest_filing_year': 'year_agg'}
    family_keys = [aggregation_level, 'docdb_family_id', 'earliest_filing_year']
    yearly_families = df_agg_corres[family_keys].drop_duplicates().rename(columns=to_generic)
    panel = counting_families(yearly_families, namefile, 'Count')

    # Cumulative stocks, then firm names
    LOGGER.info('   start add_cumulative_stocks')
    panel = add_cumulative_stocks(df_agg_corres, df_bvdid_docdbid, panel, aggregation_level, namefile)
    panel = add_names(panel, namefile)

    # Restore the caller's column names before writing the panel to disk
    from_generic = {'firm_agg_id': aggregation_level, 'year_agg': 'earliest_filing_year'}
    panel = panel.rename(columns=from_generic)
    output_path = PATHS.dropbox / 'Data_outputted/C_PatentVariables/Panel_{}_FamInfo.csv'.format(namefile)
    panel.to_csv(output_path, index=False)
    LOGGER.info('Panel_{}_FamInfo.csv Saved'.format(namefile))
